# Copyright (c) Microsoft. All rights reserved.

"""Unit tests for AgentEntity and entity operations.

Run with: pytest tests/test_entities.py -v
"""

import asyncio
from collections.abc import AsyncIterator, Callable
from datetime import datetime
from typing import Any, TypeVar
from unittest.mock import AsyncMock, Mock, patch

import pytest
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, ChatMessage, ErrorContent, Role
from pydantic import BaseModel

from agent_framework_azurefunctions._durable_agent_state import (
    DurableAgentState,
    DurableAgentStateData,
    DurableAgentStateMessage,
    DurableAgentStateRequest,
    DurableAgentStateTextContent,
)
from agent_framework_azurefunctions._entities import AgentEntity, create_agent_entity
from agent_framework_azurefunctions._models import RunRequest

# Generic type variable for callables (bound to Callable[..., Any]).
# NOTE(review): not referenced in this portion of the file — confirm it is used elsewhere or remove it.
TFunc = TypeVar("TFunc", bound=Callable[..., Any])


def _role_value(chat_message: DurableAgentStateMessage) -> str:
    """Return the role of a durable-state message as a plain string.

    The ``role`` attribute may hold either a plain string or an enum-like object
    exposing ``.value`` (such as ``Role``); both are normalized to ``str``. A
    missing or ``None`` role yields an empty string.
    """
    role = getattr(chat_message, "role", None)
    # Unwrap enum-like roles; plain strings have no `.value` and pass through unchanged.
    role_value = getattr(role, "value", role)
    if role_value is None:
        return ""
    return str(role_value)


def _agent_response(text: str | None) -> AgentRunResponse:
    """Build an AgentRunResponse holding exactly one assistant message.

    When ``text`` is ``None`` the message is built with an empty ``contents``
    list instead of a ``text`` value.
    """
    if text is None:
        reply = ChatMessage(role="assistant", contents=[])
    else:
        reply = ChatMessage(role="assistant", text=text)
    return AgentRunResponse(messages=[reply])


class RecordingCallback:
    """Test double recording every callback invocation it receives.

    Each callback method forwards its ``(payload, context)`` arguments to an
    ``AsyncMock`` so tests can inspect ``await_count`` and ``await_args_list``.
    """

    def __init__(self):
        # Final-response and streaming-update recorders, inspected by the tests.
        self.response_mock = AsyncMock()
        self.stream_mock = AsyncMock()

    async def on_streaming_response_update(self, update: AgentRunResponseUpdate, context: Any) -> None:
        """Record a streaming update together with its context on ``stream_mock``."""
        await self.stream_mock(update, context)

    async def on_agent_response(self, response: AgentRunResponse, context: Any) -> None:
        """Record the final response together with its context on ``response_mock``."""
        await self.response_mock(response, context)


class EntityStructuredResponse(BaseModel):
    """Pydantic model supplied as ``response_format`` in the structured-output test."""

    # Numeric answer expected in the agent's JSON payload.
    answer: float


class TestAgentEntityInit:
    """Tests covering construction of AgentEntity instances."""

    def test_init_creates_entity(self) -> None:
        """A fresh entity wraps the agent and starts from an empty, current-schema state."""
        agent = Mock()

        agent_entity = AgentEntity(agent)
        state = agent_entity.state

        assert agent_entity.agent == agent
        assert len(state.data.conversation_history) == 0
        assert state.data.extension_data is None
        assert state.schema_version == DurableAgentState.SCHEMA_VERSION

    def test_init_stores_agent_reference(self) -> None:
        """The wrapped agent remains reachable, with its attributes intact."""
        agent = Mock()
        agent.name = "TestAgent"

        assert AgentEntity(agent).agent.name == "TestAgent"

    def test_init_with_different_agent_types(self) -> None:
        """Agents of differing (mocked) classes are each stored unchanged."""
        expected_names = ("AzureOpenAIAgent", "CustomAgent")

        agents = []
        for class_name in expected_names:
            candidate = Mock()
            candidate.__class__.__name__ = class_name
            agents.append(candidate)

        wrapped_entities = [AgentEntity(candidate) for candidate in agents]

        for wrapped, class_name in zip(wrapped_entities, expected_names, strict=True):
            assert wrapped.agent.__class__.__name__ == class_name


class TestAgentEntityRunAgent:
    """Test suite for the run_agent operation."""

    async def test_run_agent_executes_agent(self) -> None:
        """Test that run_agent executes the agent."""
        mock_agent = Mock()
        mock_response = _agent_response("Test response")
        mock_agent.run = AsyncMock(return_value=mock_response)

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        result = await entity.run_agent(
            mock_context, {"message": "Test message", "thread_id": "conv-123", "correlationId": "corr-entity-1"}
        )

        # Verify agent.run was called once, with the user message passed via the `messages` kwarg.
        mock_agent.run.assert_called_once()
        _, kwargs = mock_agent.run.call_args
        sent_messages: list[Any] | None = kwargs.get("messages")
        # Fail with a clear message rather than a TypeError from len(None).
        assert sent_messages is not None, "agent.run() was not called with a 'messages' keyword argument"
        assert len(sent_messages) == 1
        sent_message = sent_messages[0]
        assert isinstance(sent_message, ChatMessage)
        assert getattr(sent_message, "text", None) == "Test message"
        assert getattr(sent_message.role, "value", sent_message.role) == "user"

        # Verify result
        assert isinstance(result, AgentRunResponse)
        assert result.text == "Test response"

    async def test_run_agent_streaming_callbacks_invoked(self) -> None:
        """Ensure streaming updates trigger callbacks and run() is not used."""

        updates = [
            AgentRunResponseUpdate(text="Hello"),
            AgentRunResponseUpdate(text=" world"),
        ]

        async def update_generator() -> AsyncIterator[AgentRunResponseUpdate]:
            for update in updates:
                yield update

        mock_agent = Mock()
        mock_agent.name = "StreamingAgent"
        mock_agent.run_stream = Mock(return_value=update_generator())
        mock_agent.run = AsyncMock(side_effect=AssertionError("run() should not be called when streaming succeeds"))

        callback = RecordingCallback()
        entity = AgentEntity(mock_agent, callback=callback)
        mock_context = Mock()

        result = await entity.run_agent(
            mock_context,
            {
                "message": "Tell me something",
                "thread_id": "session-1",
                "correlationId": "corr-stream-1",
            },
        )

        assert isinstance(result, AgentRunResponse)
        assert "Hello" in result.text
        assert callback.stream_mock.await_count == len(updates)
        assert callback.response_mock.await_count == 1
        mock_agent.run.assert_not_called()

        # Validate callback arguments: each update is forwarded as-is with a populated context.
        stream_calls = callback.stream_mock.await_args_list
        for expected_update, recorded_call in zip(updates, stream_calls, strict=True):
            assert recorded_call.args[0] is expected_update
            context = recorded_call.args[1]
            assert context.agent_name == "StreamingAgent"
            assert context.correlation_id == "corr-stream-1"
            assert context.thread_id == "session-1"
            assert context.request_message == "Tell me something"

        final_call = callback.response_mock.await_args
        assert final_call is not None
        final_response, final_context = final_call.args
        assert final_context.agent_name == "StreamingAgent"
        assert final_context.correlation_id == "corr-stream-1"
        assert final_context.thread_id == "session-1"
        assert final_context.request_message == "Tell me something"
        assert getattr(final_response, "text", "").strip()

    async def test_run_agent_final_callback_without_streaming(self) -> None:
        """Ensure the final callback fires even when streaming is unavailable."""

        mock_agent = Mock()
        mock_agent.name = "NonStreamingAgent"
        mock_agent.run_stream = None
        agent_response = _agent_response("Final response")
        mock_agent.run = AsyncMock(return_value=agent_response)

        callback = RecordingCallback()
        entity = AgentEntity(mock_agent, callback=callback)
        mock_context = Mock()

        result = await entity.run_agent(
            mock_context,
            {
                "message": "Hi",
                "thread_id": "session-2",
                "correlationId": "corr-final-1",
            },
        )

        assert isinstance(result, AgentRunResponse)
        assert result.text == "Final response"
        assert callback.stream_mock.await_count == 0
        assert callback.response_mock.await_count == 1

        final_call = callback.response_mock.await_args
        assert final_call is not None
        assert final_call.args[0] is agent_response
        final_context = final_call.args[1]
        assert final_context.agent_name == "NonStreamingAgent"
        assert final_context.correlation_id == "corr-final-1"
        assert final_context.thread_id == "session-2"
        assert final_context.request_message == "Hi"

    async def test_run_agent_updates_conversation_history(self) -> None:
        """Test that run_agent updates the conversation history."""
        mock_agent = Mock()
        mock_response = _agent_response("Agent response")
        mock_agent.run = AsyncMock(return_value=mock_response)

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        await entity.run_agent(
            mock_context, {"message": "User message", "thread_id": "conv-1", "correlationId": "corr-entity-2"}
        )

        # One run produces 2 entries: the user request followed by the assistant response.
        history = entity.state.data.conversation_history
        assert len(history) == 2
        user_history = history[0].messages
        assistant_history = history[1].messages

        assert len(user_history) == 1
        assert len(assistant_history) == 1

        user_msg = user_history[0]
        assert _role_value(user_msg) == "user"
        assert user_msg.text == "User message"

        assistant_msg = assistant_history[0]
        assert _role_value(assistant_msg) == "assistant"
        assert assistant_msg.text == "Agent response"

    async def test_run_agent_increments_message_count(self) -> None:
        """Test that run_agent increments the message count."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(return_value=_agent_response("Response"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        assert len(entity.state.data.conversation_history) == 0
        assert entity.state.message_count == 0

        # Each run appends a request entry and a response entry.
        await entity.run_agent(
            mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-3a"}
        )
        assert len(entity.state.data.conversation_history) == 2
        assert entity.state.message_count == 2

        await entity.run_agent(
            mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-3b"}
        )
        assert len(entity.state.data.conversation_history) == 4
        assert entity.state.message_count == 4

        await entity.run_agent(
            mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-3c"}
        )
        assert len(entity.state.data.conversation_history) == 6
        assert entity.state.message_count == 6

    async def test_run_agent_with_none_thread_id(self) -> None:
        """Test run_agent with a None thread identifier."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(return_value=_agent_response("Response"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        with pytest.raises(ValueError, match="thread_id"):
            await entity.run_agent(
                mock_context, {"message": "Message", "thread_id": None, "correlationId": "corr-entity-5"}
            )

    async def test_run_agent_multiple_conversations(self) -> None:
        """Test that run_agent maintains history across multiple messages."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(return_value=_agent_response("Response"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        # Send multiple messages
        await entity.run_agent(
            mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-8a"}
        )
        await entity.run_agent(
            mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-8b"}
        )
        await entity.run_agent(
            mock_context, {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-8c"}
        )

        history = entity.state.data.conversation_history
        assert len(history) == 6
        assert entity.state.message_count == 6


class TestAgentEntityReset:
    """Test suite for the reset operation."""

    @staticmethod
    def _request_entry(text: str) -> DurableAgentStateRequest:
        """Build a single-message user request entry for seeding conversation history."""
        return DurableAgentStateRequest(
            correlation_id="test-1",
            created_at=datetime.now(),
            messages=[
                DurableAgentStateMessage(
                    role="user",
                    contents=[DurableAgentStateTextContent(text=text)],
                )
            ],
        )

    def test_reset_clears_conversation_history(self) -> None:
        """Test that reset clears the conversation history."""
        mock_agent = Mock()
        entity = AgentEntity(mock_agent)

        # Add some history with proper DurableAgentStateEntry objects
        entity.state.data.conversation_history = [self._request_entry("msg1")]

        mock_context = Mock()
        entity.reset(mock_context)

        assert entity.state.data.conversation_history == []

    def test_reset_with_extension_data(self) -> None:
        """Test that reset works when entity has extension data."""
        mock_agent = Mock()
        entity = AgentEntity(mock_agent)

        # Replace the state data with an empty history that carries extension data.
        entity.state.data = DurableAgentStateData(conversation_history=[], extension_data={"some_key": "some_value"})

        mock_context = Mock()
        entity.reset(mock_context)

        assert len(entity.state.data.conversation_history) == 0

    def test_reset_clears_message_count(self) -> None:
        """Test that reset clears the message count."""
        mock_agent = Mock()
        entity = AgentEntity(mock_agent)

        # Seed one entry so there is a non-zero count to clear; without this the test passes vacuously.
        entity.state.data.conversation_history = [self._request_entry("msg1")]
        assert entity.state.message_count == 1

        mock_context = Mock()
        entity.reset(mock_context)

        assert entity.state.message_count == 0
        assert len(entity.state.data.conversation_history) == 0

    async def test_reset_after_conversation(self) -> None:
        """Test reset after a full conversation."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(return_value=_agent_response("Response"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        # Have a conversation
        await entity.run_agent(
            mock_context, {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-10a"}
        )
        await entity.run_agent(
            mock_context, {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-10b"}
        )

        # Verify state before reset
        assert entity.state.message_count == 4
        assert len(entity.state.data.conversation_history) == 4

        # Reset
        entity.reset(mock_context)

        # Verify state after reset
        assert entity.state.message_count == 0
        assert len(entity.state.data.conversation_history) == 0


class TestCreateAgentEntity:
    """Test suite for the create_agent_entity factory function."""

    def test_create_agent_entity_returns_callable(self) -> None:
        """Test that create_agent_entity returns a callable."""
        mock_agent = Mock()

        entity_function = create_agent_entity(mock_agent)

        assert callable(entity_function)

    def test_entity_function_handles_run_agent(self) -> None:
        """Test that the entity function handles the run_agent operation."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(return_value=_agent_response("Response"))

        entity_function = create_agent_entity(mock_agent)

        # Mock context
        mock_context = Mock()
        mock_context.operation_name = "run_agent"
        mock_context.get_input.return_value = {
            "message": "Test message",
            "thread_id": "conv-123",
            "correlationId": "corr-entity-factory",
        }
        mock_context.get_state.return_value = None

        # Execute
        entity_function(mock_context)

        # Verify the serialized agent response was returned as the operation result.
        assert mock_context.set_result.called
        result = mock_context.set_result.call_args[0][0]
        assert isinstance(result, dict)
        assert "messages" in result

        # Verify the persisted state holds the request entry and the response entry.
        assert mock_context.set_state.called
        state = mock_context.set_state.call_args[0][0]
        assert len(state["data"]["conversationHistory"]) == 2

    def test_entity_function_handles_reset(self) -> None:
        """Test that the entity function handles the reset operation."""
        mock_agent = Mock()

        entity_function = create_agent_entity(mock_agent)

        # Mock context with existing state
        mock_context = Mock()
        mock_context.operation_name = "reset"
        mock_context.get_state.return_value = {
            "schemaVersion": "1.0.0",
            "data": {
                "conversationHistory": [
                    {
                        "$type": "request",
                        "correlationId": "test-correlation-id",
                        "createdAt": "2024-01-01T00:00:00Z",
                        "messages": [
                            {
                                "role": "user",
                                "contents": [{"$type": "text", "text": "test"}],
                            }
                        ],
                    }
                ]
            },
        }

        # Execute
        entity_function(mock_context)

        # Verify reset result
        assert mock_context.set_result.called
        result = mock_context.set_result.call_args[0][0]
        assert result["status"] == "reset"

        # Verify state was cleared
        assert mock_context.set_state.called
        state = mock_context.set_state.call_args[0][0]
        assert state["data"]["conversationHistory"] == []

    def test_entity_function_handles_unknown_operation(self) -> None:
        """Test that the entity function handles unknown operations."""
        mock_agent = Mock()

        entity_function = create_agent_entity(mock_agent)

        mock_context = Mock()
        mock_context.operation_name = "invalid_operation"
        mock_context.get_state.return_value = None

        # Execute
        entity_function(mock_context)

        # Verify error result mentions the rejected operation name.
        assert mock_context.set_result.called
        result = mock_context.set_result.call_args[0][0]
        assert "error" in result
        assert "invalid_operation" in result["error"].lower()

    def test_entity_function_creates_new_entity_on_first_call(self) -> None:
        """Test that the entity function creates a new entity when no state exists."""
        mock_agent = Mock()
        mock_agent.__class__.__name__ = "Agent"

        entity_function = create_agent_entity(mock_agent)
        mock_context = Mock()
        mock_context.operation_name = "reset"
        mock_context.get_state.return_value = None  # No existing state

        # Execute
        entity_function(mock_context)

        # Verify new entity state was created
        assert mock_context.set_result.called
        result = mock_context.set_result.call_args[0][0]
        assert result["status"] == "reset"
        assert mock_context.set_state.called
        state = mock_context.set_state.call_args[0][0]
        assert state["data"] == {"conversationHistory": []}

    def test_entity_function_restores_existing_state(self) -> None:
        """Test that the entity function restores existing state."""
        mock_agent = Mock()

        entity_function = create_agent_entity(mock_agent)

        existing_state = {
            "schemaVersion": "1.0.0",
            "data": {
                "conversationHistory": [
                    {
                        "$type": "request",
                        "correlationId": "corr-existing-1",
                        "createdAt": "2024-01-01T00:00:00Z",
                        "messages": [
                            {
                                "role": "user",
                                "contents": [
                                    {
                                        "$type": "text",
                                        "text": "msg1",
                                    }
                                ],
                            }
                        ],
                    },
                    {
                        "$type": "response",
                        "correlationId": "corr-existing-1",
                        "createdAt": "2024-01-01T00:05:00Z",
                        "messages": [
                            {
                                "role": "assistant",
                                "contents": [
                                    {
                                        "$type": "text",
                                        "text": "resp1",
                                    }
                                ],
                            }
                        ],
                    },
                ],
            },
        }

        mock_context = Mock()
        mock_context.operation_name = "reset"
        mock_context.get_state.return_value = existing_state

        # `wraps` keeps the real deserialization while recording how it was called.
        with patch.object(DurableAgentState, "from_dict", wraps=DurableAgentState.from_dict) as from_dict_mock:
            entity_function(mock_context)

        from_dict_mock.assert_called_once_with(existing_state)

        # The restored entity must still complete the requested operation.
        assert mock_context.set_result.called
        result = mock_context.set_result.call_args[0][0]
        assert result["status"] == "reset"


class TestErrorHandling:
    """Test suite for error handling in entities."""

    async def test_run_agent_handles_agent_exception(self) -> None:
        """Test that run_agent handles agent exceptions."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(side_effect=Exception("Agent failed"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        result = await entity.run_agent(
            mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-1"}
        )

        assert isinstance(result, AgentRunResponse)
        assert len(result.messages) == 1
        content = result.messages[0].contents[0]
        assert isinstance(content, ErrorContent)
        assert "Agent failed" in (content.message or "")
        assert content.error_code == "Exception"

    async def test_run_agent_handles_value_error(self) -> None:
        """Test that run_agent handles ValueError instances."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(side_effect=ValueError("Invalid input"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        result = await entity.run_agent(
            mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-2"}
        )

        assert isinstance(result, AgentRunResponse)
        assert len(result.messages) == 1
        content = result.messages[0].contents[0]
        assert isinstance(content, ErrorContent)
        assert content.error_code == "ValueError"
        assert "Invalid input" in str(content.message)

    async def test_run_agent_handles_timeout_error(self) -> None:
        """Test that run_agent handles TimeoutError instances."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(side_effect=TimeoutError("Request timeout"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        result = await entity.run_agent(
            mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-error-3"}
        )

        assert isinstance(result, AgentRunResponse)
        assert len(result.messages) == 1
        content = result.messages[0].contents[0]
        assert isinstance(content, ErrorContent)
        assert content.error_code == "TimeoutError"
        # Match the sibling tests: the exception text must surface in the error content.
        assert "Request timeout" in (content.message or "")

    def test_entity_function_handles_exception_in_operation(self) -> None:
        """Test that the entity function handles exceptions gracefully."""
        mock_agent = Mock()

        entity_function = create_agent_entity(mock_agent)

        mock_context = Mock()
        mock_context.operation_name = "run_agent"
        mock_context.get_input.side_effect = Exception("Input error")
        mock_context.get_state.return_value = None

        # Execute - should not raise
        entity_function(mock_context)

        # Verify error was set
        assert mock_context.set_result.called
        result = mock_context.set_result.call_args[0][0]
        assert "error" in result

    async def test_run_agent_preserves_message_on_error(self) -> None:
        """Test that run_agent preserves message information on error."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(side_effect=Exception("Error"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        result = await entity.run_agent(
            mock_context,
            {"message": "Test message", "thread_id": "conv-123", "correlationId": "corr-entity-error-4"},
        )

        # Even on error, the failure details must be preserved in the returned ErrorContent.
        assert isinstance(result, AgentRunResponse)
        assert len(result.messages) == 1
        content = result.messages[0].contents[0]
        assert isinstance(content, ErrorContent)
        assert "Error" in (content.message or "")
        assert content.error_code == "Exception"


class TestConversationHistory:
    """Test suite for conversation history tracking."""

    async def test_conversation_history_has_timestamps(self) -> None:
        """Test that conversation history entries include timestamps."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(return_value=_agent_response("Response"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        await entity.run_agent(
            mock_context, {"message": "Message", "thread_id": "conv-1", "correlationId": "corr-entity-history-1"}
        )

        history = entity.state.data.conversation_history
        # Guard against a vacuous pass: one run must yield a request entry and a response entry.
        assert len(history) == 2

        # Check both the request and the response entries carry a timestamp.
        for entry in history:
            timestamp = entry.created_at
            assert timestamp is not None
            # The timestamp's string form must be parseable by datetime.fromisoformat.
            datetime.fromisoformat(str(timestamp))

    async def test_conversation_history_ordering(self) -> None:
        """Test that conversation history maintains the correct order."""
        mock_agent = Mock()

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        # Send multiple messages with different responses
        mock_agent.run = AsyncMock(return_value=_agent_response("Response 1"))
        await entity.run_agent(
            mock_context,
            {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-history-2a"},
        )

        mock_agent.run = AsyncMock(return_value=_agent_response("Response 2"))
        await entity.run_agent(
            mock_context,
            {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-history-2b"},
        )

        mock_agent.run = AsyncMock(return_value=_agent_response("Response 3"))
        await entity.run_agent(
            mock_context,
            {"message": "Message 3", "thread_id": "conv-1", "correlationId": "corr-entity-history-2c"},
        )

        # Verify order
        history = entity.state.data.conversation_history
        # Each conversation turn creates 2 entries: request and response
        assert history[0].messages[0].text == "Message 1"  # Request 1
        assert history[1].messages[0].text == "Response 1"  # Response 1
        assert history[2].messages[0].text == "Message 2"  # Request 2
        assert history[3].messages[0].text == "Response 2"  # Response 2
        assert history[4].messages[0].text == "Message 3"  # Request 3
        assert history[5].messages[0].text == "Response 3"  # Response 3

    async def test_conversation_history_role_alternation(self) -> None:
        """Test that conversation history alternates between user and assistant roles."""
        mock_agent = Mock()
        mock_agent.run = AsyncMock(return_value=_agent_response("Response"))

        entity = AgentEntity(mock_agent)
        mock_context = Mock()

        await entity.run_agent(
            mock_context,
            {"message": "Message 1", "thread_id": "conv-1", "correlationId": "corr-entity-history-3a"},
        )
        await entity.run_agent(
            mock_context,
            {"message": "Message 2", "thread_id": "conv-1", "correlationId": "corr-entity-history-3b"},
        )

        # Check role alternation
        history = entity.state.data.conversation_history
        # Each conversation turn creates 2 entries: request and response
        assert history[0].messages[0].role == "user"  # Request 1
        assert history[1].messages[0].role == "assistant"  # Response 1
        assert history[2].messages[0].role == "user"  # Request 2
        assert history[3].messages[0].role == "assistant"  # Response 2


class TestRunRequestSupport:
    """Tests for the request shapes accepted by AgentEntity.run_agent (RunRequest, dict, str)."""

    async def test_run_agent_with_run_request_object(self) -> None:
        """A RunRequest instance is accepted and the agent's reply is returned."""
        agent = Mock()
        agent.run = AsyncMock(return_value=_agent_response("Response"))
        agent_entity = AgentEntity(agent)
        ctx = Mock()

        run_request = RunRequest(
            message="Test message",
            thread_id="conv-123",
            role=Role.USER,
            enable_tool_calls=True,
            correlation_id="corr-runreq-1",
        )
        reply = await agent_entity.run_agent(ctx, run_request)

        assert isinstance(reply, AgentRunResponse)
        assert reply.text == "Response"

    async def test_run_agent_with_dict_request(self) -> None:
        """A plain dict payload is accepted and the agent's reply is returned."""
        agent = Mock()
        agent.run = AsyncMock(return_value=_agent_response("Response"))
        agent_entity = AgentEntity(agent)
        ctx = Mock()

        payload = dict(
            message="Test message",
            thread_id="conv-456",
            role="system",
            enable_tool_calls=False,
            correlationId="corr-runreq-2",
        )
        reply = await agent_entity.run_agent(ctx, payload)

        assert isinstance(reply, AgentRunResponse)
        assert reply.text == "Response"

    async def test_run_agent_with_string_raises_without_correlation(self) -> None:
        """A legacy bare-string request (no correlation ID) is rejected with ValueError."""
        agent = Mock()
        agent.run = AsyncMock(return_value=_agent_response("Response"))
        agent_entity = AgentEntity(agent)
        ctx = Mock()

        with pytest.raises(ValueError):
            await agent_entity.run_agent(ctx, "Simple message")

    async def test_run_agent_stores_role_in_history(self) -> None:
        """The request's role (here: system) is recorded on the first history entry."""
        agent = Mock()
        agent.run = AsyncMock(return_value=_agent_response("Response"))
        agent_entity = AgentEntity(agent)
        ctx = Mock()

        system_request = RunRequest(
            message="System message",
            thread_id="conv-runreq-3",
            role=Role.SYSTEM,
            correlation_id="corr-runreq-3",
        )
        await agent_entity.run_agent(ctx, system_request)

        first_message = agent_entity.state.data.conversation_history[0].messages[0]
        assert first_message.role == "system"
        assert first_message.text == "System message"

    async def test_run_agent_with_response_format(self) -> None:
        """With a response_format, the raw JSON text is returned and no parsed value is attached."""
        agent = Mock()
        # The agent answers with a JSON document.
        agent.run = AsyncMock(return_value=_agent_response('{"answer": 42}'))
        agent_entity = AgentEntity(agent)
        ctx = Mock()

        structured_request = RunRequest(
            message="What is the answer?",
            thread_id="conv-runreq-4",
            response_format=EntityStructuredResponse,
            correlation_id="corr-runreq-4",
        )
        reply = await agent_entity.run_agent(ctx, structured_request)

        assert isinstance(reply, AgentRunResponse)
        assert reply.text == '{"answer": 42}'
        assert reply.value is None

    async def test_run_agent_disable_tool_calls(self) -> None:
        """Disabling tool calls still results in exactly one agent invocation."""
        agent = Mock()
        agent.run = AsyncMock(return_value=_agent_response("Response"))
        agent_entity = AgentEntity(agent)
        ctx = Mock()

        no_tools_request = RunRequest(
            message="Test", thread_id="conv-runreq-5", enable_tool_calls=False, correlation_id="corr-runreq-5"
        )
        reply = await agent_entity.run_agent(ctx, no_tools_request)

        assert isinstance(reply, AgentRunResponse)
        # How tools are disabled is framework-specific; only the single invocation is asserted.
        agent.run.assert_called_once()

    async def test_entity_function_with_run_request_dict(self) -> None:
        """The entity function accepts a RunRequest-shaped dict and returns the serialized reply."""
        agent = Mock()
        agent.run = AsyncMock(return_value=_agent_response("Response"))
        entity_function = create_agent_entity(agent)

        ctx = Mock()
        ctx.operation_name = "run_agent"
        ctx.get_input.return_value = {
            "message": "Test message",
            "thread_id": "conv-789",
            "role": "user",
            "enable_tool_calls": True,
            "correlationId": "corr-runreq-6",
        }
        ctx.get_state.return_value = None

        # The entity function is synchronous, so run it off the test's event loop thread.
        await asyncio.to_thread(entity_function, ctx)

        assert ctx.set_result.called
        payload = ctx.set_result.call_args[0][0]
        assert isinstance(payload, dict)

        assert "messages" in payload
        assert len(payload["messages"]) > 0
        message = payload["messages"][0]

        # The reply text may sit directly on the message or inside one of its content dicts.
        text_found = message.get("text") == "Response" or any(
            isinstance(part, dict) and part.get("text") == "Response" for part in message.get("contents", ())
        )

        assert text_found, f"Response text not found in message: {message}"


class TestDurableAgentStateRequestOrchestrationId:
    """Tests for the optional ``orchestration_id`` field on DurableAgentStateRequest."""

    @staticmethod
    def _build_request(**extra: Any) -> DurableAgentStateRequest:
        """Return a minimal one-message user request, forwarding any extra keyword arguments.

        Keyword arguments not supplied by the caller are not passed at all. This keeps
        the constructor's own defaults in effect for them.
        """
        return DurableAgentStateRequest(
            correlation_id="corr-123",
            created_at=datetime.now(),
            messages=[
                DurableAgentStateMessage(
                    role="user",
                    contents=[DurableAgentStateTextContent(text="test")],
                )
            ],
            **extra,
        )

    def test_request_with_orchestration_id(self) -> None:
        """The constructor keeps the orchestration_id it was given."""
        built = self._build_request(orchestration_id="orch-456")

        assert built.orchestration_id == "orch-456"

    def test_request_to_dict_includes_orchestration_id(self) -> None:
        """to_dict emits an ``orchestrationId`` key with the configured value."""
        serialized = self._build_request(orchestration_id="orch-789").to_dict()

        assert "orchestrationId" in serialized
        assert serialized["orchestrationId"] == "orch-789"

    def test_request_to_dict_excludes_orchestration_id_when_none(self) -> None:
        """to_dict omits ``orchestrationId`` when no orchestration_id was supplied."""
        serialized = self._build_request().to_dict()

        assert "orchestrationId" not in serialized

    def test_request_from_dict_with_orchestration_id(self) -> None:
        """from_dict reads the ``orchestrationId`` key into orchestration_id."""
        payload = {
            "$type": "request",
            "correlationId": "corr-123",
            "createdAt": "2024-01-01T00:00:00Z",
            "messages": [{"role": "user", "contents": [{"$type": "text", "text": "test"}]}],
            "orchestrationId": "orch-from-dict",
        }

        parsed = DurableAgentStateRequest.from_dict(payload)

        assert parsed.orchestration_id == "orch-from-dict"

    def test_request_from_run_request_with_orchestration_id(self) -> None:
        """from_run_request copies orchestration_id over from the RunRequest."""
        source = RunRequest(
            message="test message",
            correlation_id="corr-run",
            orchestration_id="orch-from-run-request",
        )

        converted = DurableAgentStateRequest.from_run_request(source)

        assert converted.orchestration_id == "orch-from-run-request"

    def test_request_from_run_request_without_orchestration_id(self) -> None:
        """from_run_request leaves orchestration_id as None when the RunRequest has none."""
        source = RunRequest(
            message="test message",
            correlation_id="corr-run",
        )

        converted = DurableAgentStateRequest.from_run_request(source)

        assert converted.orchestration_id is None


if __name__ == "__main__":
    # pytest.main returns an exit code (0 on success). Raising SystemExit with it
    # makes a direct `python test_entities.py` run report failures to the shell/CI
    # instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
